EchoSpike¶

In [ ]:
import matplotlib.pyplot as plt
from tqdm.notebook import trange
from utils import test
from data import load_classwise_PMNIST, load_classwise_NMNIST
from model import EchoSpike, simple_out
import numpy as np
import torch
import seaborn as sns
from scipy.signal import savgol_filter
import pickle
from main import Args
from matplotlib import pyplot
pyplot.rcParams['figure.dpi'] = 600

color_list = sns.color_palette('muted')
device = 'cpu'
folder = 'models/'
model_name = folder + 'online_nmnist.pt'
# with open(model_name[:-3] + '_args.pkl', 'rb') as f:
#     args = pickle.load(f)
#     args.device = device
args = Args()
print(vars(args))
/home/lars/miniconda3/lib/python3.9/site-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
  warnings.warn(_BETA_TRANSFORMS_WARNING)
/home/lars/miniconda3/lib/python3.9/site-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
  warnings.warn(_BETA_TRANSFORMS_WARNING)
{'model_name': 'test', 'dataset': 'nmnist', 'online': True, 'device': 'cpu', 'recurrency_type': 'none', 'lr': 0.0001, 'epochs': 100, 'augment': True, 'batch_size': 128, 'n_hidden': [200, 200, 200], 'c_y': [2, -1], 'inp_thr': 0.02, 'n_inputs': 2312, 'n_outputs': 10, 'n_time_bins': 10, 'beta': 0.9}

Dataset¶

N-MNIST

In [ ]:
if args.dataset == 'mnist':
    train_loader, train_loader2, test_loader = load_classwise_PMNIST(args.n_time_bins, scale=args.poisson_scale, split_train=True) #load_NMNIST(n_time_bins, batch_size=batch_size)
else:
    train_loader, train_loader2, test_loader = load_classwise_NMNIST(args.n_time_bins, split_train=True)
# Plot Example
frames, target = train_loader.next_item(-1)
print(frames.shape, f'Target Digit: {target.item()}')
plt.figure()
if args.dataset == 'mnist':
    plt.imshow(frames[0].view(28,28), cmap='gray')
else:
    plt.imshow(frames.squeeze().sum(axis=0).view(2,34,34)[0], cmap='gray')
(34, 34, 2)
/home/lars/ownCloud/ETH/Master/Project_2/SNN_CLAPP/data.py:28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.y = torch.tensor(y)
torch.Size([10, 1, 2312]) Target Digit: 4

Load pretrained model¶

In [ ]:
SNN = EchoSpike(args.n_inputs, args.n_hidden, beta=args.beta, recurrency_type=args.recurrency_type).to(device)
# state_dict = torch.load(model_name, map_location=args.device)
# state_dict = {key.replace('clapp', 'layers'):value for key, value in state_dict.items()}
# # overwrite the state dict
# torch.save(state_dict, model_name)
SNN.load_state_dict(torch.load(model_name, map_location=device))
# Load and Plot train loss history
echo_train_losses = torch.load(f'{model_name[:-3]}_loss_hist.pt', map_location=device)
for i in range(echo_train_losses.shape[1]):
    plt.plot(np.linspace(0, args.epochs, len(echo_train_losses)), savgol_filter(echo_train_losses[:,i], 99, 1), label=f'Layer {i+1}', color=color_list[i])
plt.ylabel('EchoSpike Loss')
# no y ticks, because it's not really meaningful
plt.yticks([])
plt.xlabel('Epoch')
# plt.title('EchoSpike Loss During Training for Each Layer');
plt.legend();

Run EchoSpike on the test set and get hidden states¶

In [ ]:
echo_activation, target_list, echo_losses = test(SNN, test_loader, device, batch_size=args.batch_size)
print(f'EchoSpike loss per layer: {torch.stack(echo_losses).mean(axis=0).numpy()}')
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 echo_activation, target_list, echo_losses = test(SNN, test_loader, device, batch_size=args.batch_size)
      2 print(f'EchoSpike loss per layer: {torch.stack(echo_losses).mean(axis=0).numpy()}')

File ~/ownCloud/ETH/Master/Project_2/SNN_CLAPP/utils.py:96, in test(net, testloader, device, batch_size)
     94 target = [torch.randint(testloader.num_classes, (1,)).item() for _ in range(batch_size)]
     95 while True:
---> 96     data, target = testloader.next_item(target, contrastive=(bf==-1))
     97     target_list.append(target)
     98     data = data.float().to(device)

File ~/ownCloud/ETH/Master/Project_2/SNN_CLAPP/data.py:69, in classwise_loader.next_item(self, target, contrastive)
     67 targets = []
     68 for i in indeces:
---> 69     im, t = self.data[i]
     70     imgs.append(torch.tensor(im).view(im.shape[0], -1))
     71     targets.append(t)

File ~/miniconda3/lib/python3.9/site-packages/tonic/cached_dataset.py:137, in DiskCachedDataset.__getitem__(self, item)
    135 file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5")
    136 try:
--> 137     data, targets = load_from_disk_cache(file_path)
    138 except (FileNotFoundError, OSError) as _:
    139     logging.info(
    140         f"Data {item}: {file_path} not in cache, generating it now",
    141         stacklevel=2,
    142     )

File ~/miniconda3/lib/python3.9/site-packages/tonic/cached_dataset.py:221, in load_from_disk_cache(file_path)
    216                 data = {
    217                     key: f[f"{name}/{index}/{key}"][()]
    218                     for key in f[f"{name}/{index}"].keys()
    219                 }
    220             else:
--> 221                 data = f[f"{name}/{index}"][()]
    222             _list.append(data)
    223 if len(data_list) == 1:

File h5py/_objects.pyx:54, in h5py._objects.with_phil.wrapper()

File h5py/_objects.pyx:55, in h5py._objects.with_phil.wrapper()

File ~/miniconda3/lib/python3.9/site-packages/h5py/_hl/dataset.py:790, in Dataset.__getitem__(self, args, new_dtype)
    787     return self.fields(names, _prior_dtype=new_dtype)[args]
    789 if new_dtype is None:
--> 790     new_dtype = self.dtype
    791 mtype = h5t.py_create(new_dtype)
    793 # === Special-case region references ====

File h5py/_objects.pyx:54, in h5py._objects.with_phil.wrapper()

File h5py/_objects.pyx:55, in h5py._objects.with_phil.wrapper()

File ~/miniconda3/lib/python3.9/site-packages/h5py/_hl/dataset.py:541, in Dataset.dtype(self)
    538         self._cache_props['_fast_reader'] = rdr
    539     return rdr
--> 541 @property
    542 @with_phil
    543 def dtype(self):
    544     """Numpy dtype representing the datatype"""
    545     return self.id.dtype

KeyboardInterrupt: 

Analyze Weights Directly¶

In [ ]:
layers = [SNN.layers[0].fc.weight]
for i in range(1, len(SNN.layers)):
    layers.append(SNN.layers[i].fc.weight @ layers[-1])

for i in range(len(SNN.layers)):
    plt.figure()
    plt.title(f'Layer {i}, Forward weights')
    plt.imshow(SNN.layers[i].fc.weight.detach())
    plt.colorbar()
for lidx, lay in enumerate(layers):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f'Receptive field, Layer {lidx}')
    for i in range(3):
        if args.dataset == 'mnist':
            axs[i].imshow(lay[i].view(28, 28).detach())
        else:
            axs[i].imshow(lay[i].view(2, 34, 34)[0].detach())

Plot the projections of the hidden states¶

In [ ]:
print(len(echo_activation))
hidden_activities_transformed = [[] for i in range(len(args.n_hidden))]
for ca in echo_activation:
    for ca_layer in range(len(ca)):
        hidden_activities_transformed[ca_layer].append(ca[ca_layer])
for ha_idx in range(len(args.n_hidden)):
    hidden_activities_transformed[ha_idx] = torch.stack(hidden_activities_transformed[ha_idx]).reshape(-1, hidden_activities_transformed[ha_idx][0].shape[-1])

target_transformed = torch.stack(target_list).flatten()
print(hidden_activities_transformed[0].shape, target_transformed.shape)

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP

# transform = TSNE()
# transform = PCA()
transform = UMAP()
colors = [color_list[i.int()] for i in target_transformed]

for hat in hidden_activities_transformed:
    # Number of Neurons that never spiked during the test set
    print(f'{(hat.sum(axis=0) == 0).sum()} dead neurons')
    hat_transform = transform.fit_transform(hat.detach().cpu().numpy())
    plt.figure(figsize=(8,8))
    col = colors
    # Plot each digit separately, this makes it easier to color and label them
    for i in range(args.n_outputs):
        col_indeces = np.argwhere(target_transformed.squeeze() == i).squeeze()
        hattt = hat_transform[col_indeces, :]
        plt.scatter(hattt[:,0], hattt[:,1], s=6, color=color_list[i], label=i, alpha=0.4)
    plt.legend()
79
torch.Size([10112, 200]) torch.Size([10112])
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Input In [6], in <cell line: 14>()
     12 from sklearn.decomposition import PCA
     13 from sklearn.manifold import TSNE
---> 14 from umap import UMAP
     16 # transform = TSNE()
     17 # transform = PCA()
     18 transform = UMAP()

File ~/miniconda3/lib/python3.9/site-packages/umap/__init__.py:2, in <module>
      1 from warnings import warn, catch_warnings, simplefilter
----> 2 from .umap_ import UMAP
      4 try:
      5     with catch_warnings():

File ~/miniconda3/lib/python3.9/site-packages/umap/umap_.py:48, in <module>
     41 from umap.spectral import spectral_layout, tswspectral_layout
     42 from umap.layouts import (
     43     optimize_layout_euclidean,
     44     optimize_layout_generic,
     45     optimize_layout_inverse,
     46 )
---> 48 from pynndescent import NNDescent
     49 from pynndescent.distances import named_distances as pynn_named_distances
     50 from pynndescent.sparse import sparse_named_distances as pynn_sparse_named_distances

File ~/miniconda3/lib/python3.9/site-packages/pynndescent/__init__.py:5, in <module>
      1 import sys
      3 import numba
----> 5 from .pynndescent_ import NNDescent, PyNNDescentTransformer
      7 if sys.version_info[:2] >= (3, 8):
      8     import importlib.metadata as importlib_metadata

File ~/miniconda3/lib/python3.9/site-packages/pynndescent/pynndescent_.py:22, in <module>
     12 from scipy.sparse import (
     13     csr_matrix,
     14     coo_matrix,
   (...)
     17     issparse,
     18 )
     20 import heapq
---> 22 import pynndescent.sparse as sparse
     23 import pynndescent.sparse_nndescent as sparse_nnd
     24 import pynndescent.distances as pynnd_dist

File ~/miniconda3/lib/python3.9/site-packages/pynndescent/sparse.py:519, in <module>
    502     else:
    503         return float(num_non_zero - num_equal) / num_non_zero
    506 @numba.njit(
    507     [
    508         "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
    509         numba.types.float32(
    510             numba.types.Array(numba.types.int32, 1, "C", readonly=True),
    511             numba.types.Array(numba.types.float32, 1, "C", readonly=True),
    512             numba.types.Array(numba.types.int32, 1, "C", readonly=True),
    513             numba.types.Array(numba.types.float32, 1, "C", readonly=True),
    514         ),
    515     ],
    516     fastmath=True,
    517     locals={"num_non_zero": numba.types.intp, "num_equal": numba.types.intp},
    518 )
--> 519 def sparse_alternative_jaccard(ind1, data1, ind2, data2):
    520     num_equal = fast_intersection_size(ind1, ind2)
    521     num_non_zero = ind1.shape[0] + ind2.shape[0] - num_equal

File ~/miniconda3/lib/python3.9/site-packages/numba/core/decorators.py:241, in _jit.<locals>.wrapper(func)
    239     with typeinfer.register_dispatcher(disp):
    240         for sig in sigs:
--> 241             disp.compile(sig)
    242         disp.disable_compile()
    243 return disp

File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:965, in Dispatcher.compile(self, sig)
    963 with ev.trigger_event("numba:compile", data=ev_details):
    964     try:
--> 965         cres = self._compiler.compile(args, return_type)
    966     except errors.ForceLiteralArg as e:
    967         def folded(args, kws):

File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:125, in _FunctionCompiler.compile(self, args, return_type)
    124 def compile(self, args, return_type):
--> 125     status, retval = self._compile_cached(args, return_type)
    126     if status:
    127         return retval

File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:139, in _FunctionCompiler._compile_cached(self, args, return_type)
    136     pass
    138 try:
--> 139     retval = self._compile_core(args, return_type)
    140 except errors.TypingError as e:
    141     self._failed_cache[key] = e

File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:152, in _FunctionCompiler._compile_core(self, args, return_type)
    149 flags = self._customize_flags(flags)
    151 impl = self._get_implementation(args, {})
--> 152 cres = compiler.compile_extra(self.targetdescr.typing_context,
    153                               self.targetdescr.target_context,
    154                               impl,
    155                               args=args, return_type=return_type,
    156                               flags=flags, locals=self.locals,
    157                               pipeline_class=self.pipeline_class)
    158 # Check typing error if object mode is used
    159 if cres.typing_error is not None and not flags.enable_pyobject:

File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler.py:762, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    738 """Compiler entry point
    739 
    740 Parameter
   (...)
    758     compiler pipeline
    759 """
    760 pipeline = pipeline_class(typingctx, targetctx, library,
    761                           args, return_type, flags, locals)
--> 762 return pipeline.compile_extra(func)

File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler.py:460, in CompilerBase.compile_extra(self, func)
    458 self.state.lifted = ()
    459 self.state.lifted_from = None
--> 460 return self._compile_bytecode()

File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler.py:528, in CompilerBase._compile_bytecode(self)
    524 """
    525 Populate and run pipeline for bytecode input
    526 """
    527 assert self.state.func_ir is None
--> 528 return self._compile_core()

File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler.py:494, in CompilerBase._compile_core(self)
    492 res = None
    493 try:
--> 494     pm.run(self.state)
    495     if self.state.cr is not None:
    496         break

File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File ~/miniconda3/lib/python3.9/site-packages/numba/core/typed_passes.py:468, in BaseNativeLowering.run_pass(self, state)
    466 lower.lower()
    467 if not flags.no_cpython_wrapper:
--> 468     lower.create_cpython_wrapper(flags.release_gil)
    470 if not flags.no_cfunc_wrapper:
    471     # skip cfunc wrapper generation if unsupported
    472     # argument or return types are used
    473     for t in state.args:

File ~/miniconda3/lib/python3.9/site-packages/numba/core/lowering.py:297, in BaseLower.create_cpython_wrapper(self, release_gil)
    292 if self.genlower:
    293     self.context.create_cpython_wrapper(self.library,
    294                                         self.genlower.gendesc,
    295                                         self.env, self.call_helper,
    296                                         release_gil=release_gil)
--> 297 self.context.create_cpython_wrapper(self.library, self.fndesc,
    298                                     self.env, self.call_helper,
    299                                     release_gil=release_gil)

File ~/miniconda3/lib/python3.9/site-packages/numba/core/cpu.py:191, in CPUContext.create_cpython_wrapper(self, library, fndesc, env, call_helper, release_gil)
    187 builder = PyCallWrapper(self, wrapper_module, wrapper_callee,
    188                         fndesc, env, call_helper=call_helper,
    189                         release_gil=release_gil)
    190 builder.build()
--> 191 library.add_ir_module(wrapper_module)

File ~/miniconda3/lib/python3.9/site-packages/numba/core/codegen.py:730, in CPUCodeLibrary.add_ir_module(self, ir_module)
    728 ll_module.name = ir_module.name
    729 ll_module.verify()
--> 730 self.add_llvm_module(ll_module)

File ~/miniconda3/lib/python3.9/site-packages/numba/core/codegen.py:737, in CPUCodeLibrary.add_llvm_module(self, ll_module)
    735 if not config.LLVM_REFPRUNE_PASS:
    736     ll_module = remove_redundant_nrt_refct(ll_module)
--> 737 self._final_module.link_in(ll_module)

File ~/miniconda3/lib/python3.9/site-packages/llvmlite/binding/module.py:174, in ModuleRef.link_in(self, other, preserve)
    172 if preserve:
    173     other = other.clone()
--> 174 link_modules(self, other)

File ~/miniconda3/lib/python3.9/site-packages/llvmlite/binding/linker.py:7, in link_modules(dst, src)
      5 def link_modules(dst, src):
      6     with ffi.OutputString() as outerr:
----> 7         err = ffi.lib.LLVMPY_LinkModules(dst, src, outerr)
      8         # The underlying module was destroyed
      9         src.detach()

File ~/miniconda3/lib/python3.9/site-packages/llvmlite/binding/ffi.py:152, in _lib_fn_wrapper.__call__(self, *args, **kwargs)
    150 def __call__(self, *args, **kwargs):
    151     with self._lock:
--> 152         return self._cfn(*args, **kwargs)

KeyboardInterrupt: 

Train output Projection From Each layer and from Inputs directly¶

In [ ]:
from tqdm.notebook import tqdm

def train_out_proj(epochs, batch, out_projs=None, cat=False):
    # train output projections from all layers (and no layer)
    dataloader = train_loader2
    losses_out = []
    optimizers = []
    print_interval = 40*batch
    if out_projs is None:
        out_projs = []
        out_proj_0 = simple_out(args.n_inputs, args.n_outputs, beta=1.0)
    else:
        for out_p in out_projs:
            out_p.train()
            out_p.reset()
        out_proj_0 = out_projs[0]
        out_projs = out_projs[1:]
    optim_0 = torch.optim.Adam(out_proj_0.parameters(), lr=1e-2)
    for lay in range(len(SNN.layers)):
        if len(out_projs) <= lay:
            if cat:
                out_projs.append(simple_out(sum(args.n_hidden[:lay+1]) + args.n_inputs, args.n_outputs, beta=1.0))
            else:
                out_projs.append(simple_out(args.n_hidden[lay], args.n_outputs, beta=1.0))
        optimizers.append(torch.optim.Adam(out_projs[lay].parameters(), lr=1e-2))
        optimizers[-1].zero_grad()
    SNN.eval()
    target = batch*[0]
    acc = []
    correct = (len(SNN.layers) + 1)*[0]
    with torch.no_grad():
        pbar = tqdm(total=len(dataloader)*epochs)
        while len(losses_out)*batch < len(dataloader)*epochs:
            data, target = dataloader.next_item(target, contrastive=True)
            SNN.reset(0)
            logit_lists = [[] for lay in range(len(SNN.layers)+1)]
            data = data.squeeze()
            for step in range(data.shape[0]):
                data_step = data[step].float().to(device)
                target = target.to(device)
                logits, _, _ = SNN(data_step, 0)
                if step == args.n_time_bins-1:
                    _, logts = out_proj_0(data_step, target)
                    logit_lists[0] = logts
                    for lay in range(len(SNN.layers)):
                        if cat:
                            _, logts = out_projs[lay](torch.cat([data_step, *logits[:lay+1]], dim=-1), target)
                        else:
                            _, logts = out_projs[lay](logits[lay], target)
                        logit_lists[lay+1] = logts
                else:
                    out_proj_0(data_step, None)
                    for lay in range(len(SNN.layers)):
                        if cat:
                            out_projs[lay](torch.cat([data_step, *logits[:lay+1]], dim=-1), None)
                        else:
                            out_projs[lay](logits[lay], None)
            
            preds = [logit_lists[lay].argmax(axis=-1) for lay in range(len(SNN.layers)+1)]
            # if pred.max() < 1: print(pred.max())
            dL = [preds[lay] == target for lay in range(len(SNN.layers)+1)]
            correct = [correct[lay] + dL[lay].sum() for lay in range(len(SNN.layers)+1)]
            out_proj_0.reset()
            for i, out_proj in enumerate(out_projs):
                out_proj.reset()

            losses_out.append(torch.tensor([torch.nn.functional.cross_entropy(logit_lists[lay], target.squeeze().long()) for lay in range(len(SNN.layers)+1)], requires_grad=False))

            optim_0.step()
            optim_0.zero_grad()
            for opt in optimizers:
                opt.step()
                opt.zero_grad()
            
            if len(losses_out)*batch % print_interval == 0:
                pbar.write(f'Cross Entropy Loss: {(torch.stack(losses_out)[-print_interval//batch:].sum(dim=0)/(print_interval//batch)).numpy()}\n' +
                           f'Correct: {100*np.array(correct)/print_interval}%')
                acc.append(np.array(correct)/print_interval)
                correct = (len(SNN.layers) + 1)*[0]
            pbar.update(batch)
    return [out_proj_0, *out_projs], np.asarray(acc), torch.stack(losses_out)

with torch.no_grad():
    cat = False
    # repeat 10 times
    test_accs = []
    train_accs = []
    for i in range(10):
        # new random seed
        torch.manual_seed(i)
        out_projs, acc, losses_out = train_out_proj(1, 30, cat=cat)
        test_accs.append(get_accuracy(test_loader, out_projs, cat=cat)[0])
        train_accs.append(get_accuracy(train_loader2, out_projs, cat=cat)[0])
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.3382182  0.47337666 0.42744297 0.37736744]
Correct: [71.83333333 89.16666667 88.91666667 90.75      ]%
Cross Entropy Loss: [1.6546555  0.17299062 0.17895769 0.16855808]
Correct: [79.91666667 95.16666667 95.75       96.16666667]%
Cross Entropy Loss: [1.2595325  0.15038992 0.16651992 0.16367355]
Correct: [84.83333333 95.16666667 95.5        95.83333333]%
Cross Entropy Loss: [1.2586384  0.1645219  0.17492877 0.17914307]
Correct: [85.58333333 95.25       95.33333333 95.33333333]%
Cross Entropy Loss: [1.553818   0.15008847 0.15458283 0.16141796]
Correct: [86.         96.16666667 96.25       96.25      ]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 86.06%
From layer 1:
Accuracy: 95.28%
From layer 2:
Accuracy: 95.44%
From layer 3:
Accuracy: 95.25%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 87.02%
From layer 1:
Accuracy: 96.67%
From layer 2:
Accuracy: 96.37%
From layer 3:
Accuracy: 96.00%
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.9430552  0.48246732 0.39133295 0.38660818]
Correct: [66.25 87.   89.75 90.5 ]%
Cross Entropy Loss: [1.4480274  0.19134916 0.201902   0.20016089]
Correct: [84.25       94.         94.91666667 94.83333333]%
Cross Entropy Loss: [1.0512209  0.13377151 0.1263952  0.12791191]
Correct: [85.91666667 96.25       96.83333333 96.58333333]%
Cross Entropy Loss: [1.3005296  0.1641307  0.18072578 0.1902675 ]
Correct: [85.33333333 95.08333333 95.16666667 95.16666667]%
Cross Entropy Loss: [1.1286719  0.1699802  0.1810993  0.17415996]
Correct: [85.08333333 95.33333333 95.83333333 95.66666667]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 80.96%
From layer 1:
Accuracy: 95.27%
From layer 2:
Accuracy: 95.24%
From layer 3:
Accuracy: 95.23%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 79.73%
From layer 1:
Accuracy: 96.52%
From layer 2:
Accuracy: 96.37%
From layer 3:
Accuracy: 96.02%
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [3.1067567  0.4774762  0.38488543 0.3602371 ]
Correct: [66.58333333 88.91666667 90.08333333 91.08333333]%
Cross Entropy Loss: [1.0771862  0.15314627 0.16538922 0.16365175]
Correct: [85.83333333 95.58333333 95.5        95.58333333]%
Cross Entropy Loss: [1.2817585  0.18245208 0.18243818 0.17942084]
Correct: [84.83333333 95.08333333 95.33333333 94.91666667]%
Cross Entropy Loss: [1.6821384  0.14353484 0.16773058 0.1662912 ]
Correct: [81.91666667 95.58333333 95.5        95.25      ]%
Cross Entropy Loss: [1.0285895  0.16867764 0.17331538 0.17302595]
Correct: [86.5        95.66666667 95.25       95.58333333]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 88.08%
From layer 1:
Accuracy: 95.70%
From layer 2:
Accuracy: 95.85%
From layer 3:
Accuracy: 95.63%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 90.65%
From layer 1:
Accuracy: 96.98%
From layer 2:
Accuracy: 96.23%
From layer 3:
Accuracy: 95.92%
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.7217298  0.46089298 0.34992224 0.3748212 ]
Correct: [69.33333333 87.58333333 91.16666667 90.25      ]%
Cross Entropy Loss: [1.4026086  0.15913293 0.15897939 0.16435106]
Correct: [83.16666667 95.25       95.75       95.58333333]%
Cross Entropy Loss: [1.5430806  0.15613322 0.15375328 0.15590101]
Correct: [82.41666667 95.5        96.         95.91666667]%
Cross Entropy Loss: [1.5777922  0.18366615 0.18653242 0.18294129]
Correct: [83.08333333 95.         95.25       95.33333333]%
Cross Entropy Loss: [1.2560295  0.1496404  0.16298625 0.15777263]
Correct: [87.41666667 96.08333333 95.91666667 95.66666667]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 86.61%
From layer 1:
Accuracy: 94.98%
From layer 2:
Accuracy: 95.14%
From layer 3:
Accuracy: 95.21%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 87.82%
From layer 1:
Accuracy: 96.52%
From layer 2:
Accuracy: 96.22%
From layer 3:
Accuracy: 95.83%
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.9129624  0.43650824 0.34120786 0.36636773]
Correct: [65.08333333 88.83333333 90.41666667 90.41666667]%
Cross Entropy Loss: [1.699683   0.1648637  0.16833332 0.16378531]
Correct: [82.75       95.5        95.66666667 95.5       ]%
Cross Entropy Loss: [1.3655012  0.16319947 0.17740284 0.18934996]
Correct: [86.         95.41666667 95.25       95.25      ]%
Cross Entropy Loss: [1.4213681  0.19052967 0.20100173 0.2064105 ]
Correct: [83.41666667 94.66666667 94.33333333 94.5       ]%
Cross Entropy Loss: [1.6758854  0.14381635 0.15897633 0.15392265]
Correct: [83.83333333 95.91666667 96.16666667 95.41666667]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 86.64%
From layer 1:
Accuracy: 95.74%
From layer 2:
Accuracy: 95.75%
From layer 3:
Accuracy: 95.69%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 89.78%
From layer 1:
Accuracy: 96.32%
From layer 2:
Accuracy: 96.47%
From layer 3:
Accuracy: 96.17%
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.6625247 0.4316801 0.3392229 0.3603861]
Correct: [66.41666667 88.75       92.33333333 91.66666667]%
Cross Entropy Loss: [1.0696148  0.15356013 0.16448016 0.16801438]
Correct: [83.33333333 95.41666667 95.66666667 95.5       ]%
Cross Entropy Loss: [1.5793362  0.17409055 0.1730139  0.17770503]
Correct: [81.08333333 95.         95.33333333 95.75      ]%
Cross Entropy Loss: [1.3155136  0.16008207 0.16914228 0.1606429 ]
Correct: [85.5        95.41666667 95.5        95.83333333]%
Cross Entropy Loss: [1.4178445  0.14877504 0.16295645 0.16942309]
Correct: [85.91666667 95.83333333 95.58333333 95.91666667]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 87.52%
From layer 1:
Accuracy: 95.72%
From layer 2:
Accuracy: 95.47%
From layer 3:
Accuracy: 95.45%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 88.10%
From layer 1:
Accuracy: 96.65%
From layer 2:
Accuracy: 96.28%
From layer 3:
Accuracy: 95.97%
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [3.376453   0.42517838 0.35822278 0.33904296]
Correct: [67.16666667 89.66666667 90.25       91.5       ]%
Cross Entropy Loss: [1.3644907  0.19069616 0.18601438 0.18697791]
Correct: [82.33333333 94.41666667 95.08333333 94.75      ]%
Cross Entropy Loss: [1.5174477  0.15634212 0.16194287 0.16461536]
Correct: [81.91666667 95.16666667 95.91666667 95.75      ]%
Cross Entropy Loss: [1.3004358  0.14357959 0.15617082 0.15745345]
Correct: [85.75       96.08333333 96.         95.75      ]%
Cross Entropy Loss: [1.1230577  0.15044114 0.16561505 0.17040502]
Correct: [85.16666667 95.91666667 95.66666667 95.66666667]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 87.96%
From layer 1:
Accuracy: 95.33%
From layer 2:
Accuracy: 95.40%
From layer 3:
Accuracy: 95.34%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 88.35%
From layer 1:
Accuracy: 96.65%
From layer 2:
Accuracy: 96.27%
From layer 3:
Accuracy: 96.05%
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [3.0279756  0.46350527 0.37065417 0.36309353]
Correct: [66.25       87.66666667 91.         91.16666667]%
Cross Entropy Loss: [1.0403515  0.15943103 0.14953628 0.14353053]
Correct: [84.83333333 95.58333333 95.83333333 95.83333333]%
Cross Entropy Loss: [1.2071066  0.18711491 0.20832297 0.21060458]
Correct: [84.75       94.83333333 94.75       94.83333333]%
Cross Entropy Loss: [1.4031492  0.15133992 0.155506   0.1607064 ]
Correct: [83.66666667 96.25       95.66666667 95.5       ]%
Cross Entropy Loss: [1.0277554  0.11540209 0.11871958 0.12798412]
Correct: [88.         96.5        96.08333333 96.        ]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 82.82%
From layer 1:
Accuracy: 95.08%
From layer 2:
Accuracy: 94.26%
From layer 3:
Accuracy: 94.41%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 88.68%
From layer 1:
Accuracy: 96.47%
From layer 2:
Accuracy: 96.20%
From layer 3:
Accuracy: 95.78%
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.7167056  0.44559088 0.37563068 0.37984776]
Correct: [70.08333333 89.66666667 90.91666667 91.16666667]%
Cross Entropy Loss: [1.6023502  0.16663647 0.17960748 0.18191728]
Correct: [81.         94.75       95.08333333 95.25      ]%
Cross Entropy Loss: [1.4329418  0.18200654 0.19142093 0.1874464 ]
Correct: [82.66666667 95.25       95.41666667 95.41666667]%
Cross Entropy Loss: [0.9958844  0.14487785 0.1586114  0.16825038]
Correct: [87.58333333 96.33333333 96.         96.16666667]%
Cross Entropy Loss: [1.414678   0.1706594  0.17077143 0.18180208]
Correct: [85.25       95.25       95.         94.41666667]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 83.47%
From layer 1:
Accuracy: 95.44%
From layer 2:
Accuracy: 95.67%
From layer 3:
Accuracy: 95.50%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 86.95%
From layer 1:
Accuracy: 96.48%
From layer 2:
Accuracy: 96.48%
From layer 3:
Accuracy: 96.22%
  0%|          | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.7415905  0.46612364 0.3127798  0.3401905 ]
Correct: [67.83333333 88.66666667 91.41666667 92.        ]%
Cross Entropy Loss: [1.170493   0.15758887 0.15697098 0.1570035 ]
Correct: [85.08333333 95.58333333 96.08333333 96.        ]%
Cross Entropy Loss: [1.1175559  0.17982589 0.1780278  0.1790441 ]
Correct: [85.16666667 95.         95.66666667 95.91666667]%
Cross Entropy Loss: [1.1843574  0.17871569 0.18905477 0.18735561]
Correct: [84.58333333 94.75       95.33333333 94.58333333]%
Cross Entropy Loss: [1.1208203  0.14442256 0.153645   0.15707959]
Correct: [86.5  96.   96.   95.75]%
  0%|          | 0/79 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 87.83%
From layer 1:
Accuracy: 95.43%
From layer 2:
Accuracy: 95.71%
From layer 3:
Accuracy: 95.63%
  0%|          | 0/47 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 90.67%
From layer 1:
Accuracy: 96.35%
From layer 2:
Accuracy: 96.50%
From layer 3:
Accuracy: 96.22%
In [ ]:
print(f'Accuracy of last quarter: {100*acc[-len(acc)//4:].mean(axis=0)}%')
plt.figure()
for i in range(acc.shape[1]):
    plt.plot(np.asarray(acc)[:,i]*100, color=color_list[i])
plt.ylabel('Accuracy [%]')
plt.xlabel('Training Step [x500]')
labels = ['From Inputs directly', *[f'From Layer {i+1}' for i in range(len(SNN.layers))]]
plt.legend(labels)
plt.ylim([90, 100])
plt.figure()
for i in range(losses_out.shape[1]):
    plt.plot(np.arange(len(losses_out))*args.batch_size/len(train_loader), savgol_filter(losses_out[:,i], 19, 1), label=labels[i], color=color_list[i])
plt.ylabel('Cross Entropy Loss')
plt.xlabel('Epoch')
plt.ylim([0, 0.6])
plt.legend();
Accuracy of last quarter: [86.625  95.9375 95.75   95.75  ]%

Test Accuracy on Test Set¶

In [ ]:
def get_accuracy(dataloader, out_projs, cat=False):
    correct = torch.zeros(len(out_projs))
    for out_proj in out_projs:
        out_proj.eval()
    SNN.eval()
    total = 0
    pred_matrix = torch.zeros(args.n_outputs, args.n_outputs)
    for idx in trange(0, len(dataloader), args.batch_size):
        for out_proj in out_projs:
            out_proj.reset()
        SNN.reset(0)
        if args.dataset == 'mnist':
            inp, target = dataloader.x[idx:idx+args.batch_size], dataloader.y[idx:idx+args.batch_size]
        else:
            flattenend_indeces = torch.cat(dataloader.target_indeces)
            indeces = flattenend_indeces[idx:idx+args.batch_size]
            until = min(args.batch_size, len(dataloader) - idx)
            inp = torch.stack([torch.tensor(dataloader.data[indeces[i]][0]).view(args.n_time_bins, -1) for i in range(until)])
            target = torch.tensor([dataloader.data[indeces[i]][1] for i in range(until)])
        logits = len(out_projs)*[torch.zeros((inp.shape[0],args.n_outputs))]
        for step in range(inp.shape[1]):
            data_step = inp[:,step].float().to(device)
            spk_step, _, _ = SNN(data_step, 0)
            spk_step = [data_step, *spk_step]
            for i, out_proj in enumerate(out_projs):
                if cat:
                    _, mem = out_proj(torch.cat(spk_step[:i+1], dim=-1), target)
                else:
                    _, mem = out_proj(spk_step[i], target)
                if step == args.n_time_bins-1:
                    logits[i] = mem
        total += inp.shape[0]
        for i, logit in enumerate(logits):
            pred = logit.argmax(axis=-1)
            correct[i] += int((pred == target).sum())
        # for the last layer create the prediction matrix
        for j in range(pred.shape[0]):
            pred_matrix[int(target[j]), int(pred[j])] += 1
    correct /= len(dataloader)
    assert total == len(dataloader)
    print('Directly from inputs:')
    print(f'Accuracy: {100*correct[0]:.2f}%')
    for i in range(len(out_projs)-1):
        print(f'From layer {i+1}:')
        print(f'Accuracy: {100*correct[i+1]:.2f}%')
    return correct, pred_matrix
correct, pred_matrix = get_accuracy(test_loader, out_projs, cat=cat)
plt.imshow(pred_matrix, origin='lower')
plt.title('Prediction Matrix for the final layer')
plt.xlabel('Prediction')
plt.ylabel('Target')
plt.xticks([i for i in range(args.n_outputs)])
plt.yticks([i for i in range(args.n_outputs)])
plt.colorbar();
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Input In [4], in <cell line: 47>()
     45         print(f'Accuracy: {100*correct[i+1]:.2f}%')
     46     return correct, pred_matrix
---> 47 correct, pred_matrix = get_accuracy(test_loader, out_projs, cat=cat)
     48 plt.imshow(pred_matrix, origin='lower')
     49 plt.title('Prediction Matrix for the final layer')

NameError: name 'out_projs' is not defined
In [ ]:
# train_accs = torch.stack(train_accs)
# test_accs = torch.stack(test_accs)
print(train_accs.shape)
print(f'Train Accuracy: {100*train_accs.mean(axis=0)}%, Std: {100*train_accs.std(axis=0)}%')
print(f'Test Accuracy: {100*test_accs.mean(axis=0)}%, Std: {100*test_accs.std(axis=0)}%')
# grouped Bar plot the Accuracies of the different layers both during training and testing
sns.set_theme(style="whitegrid")
labels = ['From Inputs Directly', *[f'From Layer {i+1}' for i in range(len(SNN.layers))]]
x = torch.arange(len(labels))  # the label locations
width = 0.35  # the width of the bars
fig, ax = plt.subplots()
print(x.shape, train_accs.mean(axis=0).shape)
rects1 = ax.bar(x - width/2, 100*test_accs.mean(axis=0), width, label='Test Accuracy', color=color_list[0])
ax.errorbar(x - width/2, 100*test_accs.mean(axis=0), yerr=100*test_accs.std(axis=0), fmt='none', capsize=6, color=color_list[3])
rects2 = ax.bar(x + width/2, 100*train_accs.mean(axis=0), width, label='Train Accuracy', color=color_list[1])
ax.errorbar(x + width/2, 100*train_accs.mean(axis=0), yerr=100*train_accs.std(axis=0), fmt='none', capsize=6, color=color_list[3])
# remove horizontal lines and spines
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.xaxis.grid(False)
plt.xticks(np.arange(len(out_projs)), labels, rotation=45)
plt.legend()
plt.ylabel('Accuracy [%]')
plt.ylim([40, 100])
torch.Size([10, 4])
Train Accuracy: tensor([87.7750, 96.5600, 96.3383, 96.0167])%, Std: tensor([3.1242, 0.1910, 0.1147, 0.1507])%
Test Accuracy: tensor([85.7950, 95.3970, 95.3930, 95.3340])%, Std: tensor([2.4968, 0.2636, 0.4586, 0.3697])%
torch.Size([4]) torch.Size([4])
Out[ ]:
(40.0, 100.0)

Few-Shot¶

In [ ]:
n_repeats = 10
fewshot_accuracies = torch.zeros((n_repeats, len(SNN.layers)))
for n in range(n_repeats):
    # Randomly select one sample of each class and save the spiking activity
    SNN.reset(0)
    one_shot_samples = torch.zeros(args.n_outputs, args.n_time_bins, args.n_inputs)
    one_shot_spks = [torch.zeros(args.n_outputs, h) for h  in args.n_hidden]
    k = 20
    for i in range(args.n_outputs):
        for j in range(k):
            img, _ = train_loader2.next_item(i, contrastive=False)
            one_shot_samples[i] = img.squeeze()
            for t in range(args.n_time_bins):
                logits, _, _ = SNN(img[t].float(), 0)
                for idx, log in enumerate(logits):
                    one_shot_spks[idx][i] += log.squeeze()

    def metric(spk, one_shot):
        dists = torch.zeros(spk.shape[0], args.n_outputs)
        for i in range(args.n_outputs):
            one_shot_i = one_shot[i] / one_shot[i].sum()
            dists[:, i] = torch.einsum('bi, i->b' , spk, one_shot_i)
        return dists

    def get_predictions(spks):
        preds = torch.zeros(len(spks), spks[0].shape[0])
        # for each layer get the prediction
        for i in range(len(spks)):
            dists = metric(spks[i], one_shot_spks[i])
            preds[i] = dists.argmax(axis=-1)
        return preds

    batch = int(len(test_loader)/10)
    correct_oneshot = torch.zeros(len(SNN.layers))
    SNN.eval()
    pred_matrix_oneshot = torch.zeros(args.n_outputs, args.n_outputs)
    for idx in trange(0, len(test_loader), batch):
        SNN.reset(0)
        if args.dataset == 'mnist':
            inp, target = test_loader.x[idx:idx+batch], test_loader.y[idx:idx+batch]
        else:
            until = min(batch, len(test_loader.data) - idx)
            inp = torch.stack([torch.tensor(test_loader.data[idx+i][0]).view(args.n_time_bins, -1) for i in range(until)])
            target = torch.tensor([test_loader.data[idx+i][1] for i in range(until)])
        logits = [torch.zeros(inp.shape[0], h) for h  in args.n_hidden]
        for step in range(inp.shape[1]):
            data_step = inp[:,step].float().to(device)
            spk_step, _, _ = SNN(data_step, 0)
            for logidx in range(len(spk_step)):
                logits[logidx] += spk_step[logidx]
        preds = get_predictions(logits)
        for i in range(preds.shape[0]):
            correct_oneshot[i] += int((preds[i] == target).sum())
        # for the last layer create the prediction matrix
        for j in range(preds.shape[1]):
            pred_matrix_oneshot[int(target[j]), int(preds[-1, j])] += 1
    correct_oneshot /= len(test_loader)
    for i in range(len(SNN.layers)):
        print(f'From layer {i+1}:')
        print(f'Accuracy: {100*correct_oneshot[i]:.2f}%')
    fewshot_accuracies[n] = correct_oneshot
    plt.imshow(pred_matrix_oneshot, origin='lower')
    plt.title('Prediction Matrix for the final layer')
    plt.xlabel('Prediction')
    plt.ylabel('Target')
    plt.xticks([i for i in range(args.n_outputs)])
    plt.yticks([i for i in range(args.n_outputs)])
    plt.colorbar();
    plt.show()
    print(f'Accuracy per Label: {100*pred_matrix_oneshot.diag()/pred_matrix_oneshot.sum(axis=1)}%') # correct axis?
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 93.77%
From layer 2:
Accuracy: 95.30%
From layer 3:
Accuracy: 95.57%
/tmp/ipykernel_1261198/1569276281.py:68: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.
  plt.colorbar();
Accuracy per Label: tensor([99.1443, 98.5232, 95.1907, 94.0629, 96.9880, 94.7781, 97.1899, 92.6627,
        94.4107, 91.6963])%
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 94.05%
From layer 2:
Accuracy: 95.18%
From layer 3:
Accuracy: 95.33%
Accuracy per Label: tensor([99.2665, 98.6287, 93.2007, 94.0629, 96.8675, 94.7781, 97.3269, 93.0178,
        95.0182, 91.8149])%
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 94.40%
From layer 2:
Accuracy: 95.32%
From layer 3:
Accuracy: 95.59%
Accuracy per Label: tensor([99.0220, 98.6287, 95.2460, 93.8300, 96.7470, 94.3864, 97.2584, 92.8994,
        95.2612, 91.4591])%
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 94.08%
From layer 2:
Accuracy: 95.15%
From layer 3:
Accuracy: 95.47%
Accuracy per Label: tensor([99.0220, 98.5232, 94.6932, 94.0629, 96.8675, 94.9086, 96.6415, 93.1361,
        95.0182, 91.5777])%
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 93.97%
From layer 2:
Accuracy: 95.18%
From layer 3:
Accuracy: 95.38%
Accuracy per Label: tensor([99.0220, 98.5232, 93.3112, 95.1106, 96.7470, 93.6031, 97.4640, 93.7278,
        95.1397, 91.5777])%
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 93.83%
From layer 2:
Accuracy: 95.11%
From layer 3:
Accuracy: 95.41%
Accuracy per Label: tensor([98.8998, 98.7342, 94.2510, 93.9464, 96.7470, 94.9086, 96.9842, 92.8994,
        94.6537, 91.9336])%
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 94.25%
From layer 2:
Accuracy: 95.25%
From layer 3:
Accuracy: 95.68%
Accuracy per Label: tensor([98.4108, 98.6287, 95.0249, 94.1793, 96.8675, 94.9086, 97.4640, 94.3195,
        94.5322, 91.5777])%
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 93.02%
From layer 2:
Accuracy: 94.84%
From layer 3:
Accuracy: 95.23%
Accuracy per Label: tensor([99.1443, 98.4177, 92.9243, 93.8300, 96.6265, 94.9086, 97.1899, 93.2544,
        95.6258, 91.3405])%
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 93.58%
From layer 2:
Accuracy: 95.04%
From layer 3:
Accuracy: 95.32%
Accuracy per Label: tensor([99.1443, 98.5232, 93.4771, 94.2957, 96.7470, 95.1697, 97.1899, 93.1361,
        94.2892, 91.6963])%
  0%|          | 0/10 [00:00<?, ?it/s]
From layer 1:
Accuracy: 94.04%
From layer 2:
Accuracy: 95.08%
From layer 3:
Accuracy: 95.38%
Accuracy per Label: tensor([98.8998, 98.6287, 93.8640, 93.7136, 97.2289, 94.9086, 97.0528, 93.7278,
        94.6537, 91.3405])%
In [ ]:
# Boxplot of the accuracies
plt.figure()
sns.set_style("whitegrid")
g = sns.boxplot(data=fewshot_accuracies*100)
# remove left spines
sns.despine(left=True)
plt.xticks(np.arange(len(SNN.layers)), [f'Layer {i+1}' for i in range(len(SNN.layers))])
plt.ylabel('Few-Shot Test Accuracy [%]')
plt.ylim([90, 100])
print(f'Average Accuracy: {100*fewshot_accuracies.mean(axis=0)}%')
print(f'Maximum Accuracy: {fewshot_accuracies.max(axis=0)}%')
Average Accuracy: tensor([93.8990, 95.1450, 95.4360])%
Maximum Accuracy: torch.return_types.max(
values=tensor([0.9440, 0.9532, 0.9568]),
indices=tensor([2, 2, 6]))%